from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import numpy as np

# Affichage de la table
def AfficheTable(P,cible,poids):
    n = len(poids)
    mat = 0.5*np.ones((n+1,cible+1))
    for i in range(n+1):
        for s in range(cible+1):
            if (i,s) in P:
                mat[i][s] = P[(i,s)]

    plt.close('all')
    fig, ax = plt.subplots()
    ax.imshow(mat, cmap='RdYlGn')

    ax.set_xticks(np.arange(-0.5, cible+1, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n+1, 1), minor=True)
    ax.grid(which='minor', color='black', linewidth=0.5)

    patch_vrai = mpatches.Patch(color='green', label='Vrai')
    patch_faux = mpatches.Patch(color='red', label='Faux')
    patch_nc = mpatches.Patch(color='khaki', label='Non calculé')
    ax.legend(handles=[patch_vrai, patch_faux, patch_nc],
               loc='upper right',
               fontsize=8,           # Taille du texte
               handlelength=1,       # Largeur des carrés
               handleheight=1,       # Hauteur des carrés
               borderpad=0.3,        # Marge intérieure
               labelspacing=0.3)     # Espacement entre les lignes

    plt.xlabel('Somme cible s')
    plt.ylabel('Nombre d\'éléments i')
    plt.title('Table de programmation dynamique')

    plt.show()


#########################################
# Valeurs optimales - Approche bottom-up
#########################################

poids = [10,10,10,1,5,10]
#poids = [12,53,68,83,81,88]
#poids = [3, 1, 4, 2]
#poids = [10, 10, 10, 1]
#poids = [4, 4, 4]

# Initialisation des données
def initialiser_donnees(poids):
    S = sum(poids)
    S_cible = S//2
    return S, S_cible

# Iinitialisation du dictionnaire P
def initialiser_table(cible):
    P = {(0,0): True}
    for s in range(1,cible+1):
        P[(0,s)] = False
    return P


# Remplissage de la table
def remplir_table(P, poids, cible):
    n = len(poids)
    for i in range(1,n+1):
        for s in range(cible+1):
            # Cas n°1 : On ne prend pas l'élément i
            if poids[i-1] > s:
                P[(i,s)] = P[(i-1,s)]
            # Sinon, cas n°2 : on prend l'élélemnt i
            else:
                P[(i,s)] = P[(i-1,s)] or P[(i-1,s-poids[i-1])]

    return P

# Recherche de la meilleur somme atteignable
def trouver_meilleure_somme_bottomup(P,cible,poids):
    n = len(poids)
    for s in range(cible,-1,-1):
        if P[(n,s)] == True:
            return s


S, S_cible = initialiser_donnees(poids)
P = initialiser_table(S_cible)
P = remplir_table(P,poids,S_cible)
S_opt = trouver_meilleure_somme_bottomup(P,S_cible,poids)
print(S_opt)
AfficheTable(P,S_cible, poids)                          # (cible+1)*(#elements+1) calculs

#########################################
# Valeurs optimales - Approche top-down
#########################################

# Complexité temporelle : O(n·S)
# Mémoire : O(n·S) pour le dictionnaire
#           O(n) pour la pile
#           donc dominé par O(n·S)
def rec_opt_val(i,s):
    # i : nombre d’éléments considérés (les i premiers)
    # s : somme cible à atteindre

    # Utilise la mémoïsation
    if (i,s) in P:
        return P[(i,s)]

    # Cas de base
    if s == 0:
        P[(i,s)] = True
        return P[(i,s)]
    if i == 0:
        P[(i,s)] = False
        return P[(i,s)]

    # Récursion cas n°1 : on ne prend pas l'élélément i
    S1 = rec_opt_val(i-1,s)

    # Cas n°2 : on prend l'élément i (si possible)
    if poids[i-1] <= s:
        S2 = rec_opt_val(i-1,s-poids[i-1])
        resultat = S1 or S2
    else:
        resultat = S1

    # Sauvegarde et retourne la valeur
    P[(i,s)] = resultat

    return P[(i,s)]


def trouver_meilleure_somme_topdown(cible,poids):
    n = len(poids)
    if rec_opt_val(len(poids),cible) == True:
        return cible
    else:
        for s in range(cible,-1,-1):
            if rec_opt_val(len(poids),s) == True:
                return s

P = {}
S_opt = trouver_meilleure_somme_topdown(S_cible,poids)
print(S_opt)
AfficheTable(P,S_cible, poids)


#########################################
# Reconstruction
#########################################

def element_pris(P,poids,i,s):
    if poids[i-1] <= s and P[(i-1,s-poids[i-1])] == True:
        return True
    else:
        return False

# O(n) car n itérations
# avec des opérations en O(1) (acces dictionnaire et comparaison)
def reconstruire_ensemble(P, poids, s_opt):
    elements = []
    n = len(poids)
    for i in range(n,0,-1):
        if element_pris(P,poids,i,s_opt):
            elements.append(i)
            s_opt = s_opt - poids[i-1]
    return elements

camion1 = reconstruire_ensemble(P,poids,S_opt)
print(camion1)

def construire_ensemble2(poids,ensemble1):
    elements = []
    for i in range(len(poids)):
        if (i+1) not in ensemble1:
            elements.append(i+1)

    return elements

camion2 = construire_ensemble2(poids,camion1)
print(camion2)









